# encoding: UTF-8

"""
一个ATR-RSI指标结合的交易策略，适合用在股指的1分钟和5分钟线上。

注意事项：
1. 作者不对交易盈利做任何保证，策略代码仅供参考
2. 本策略需要用到talib，没有安装的用户请先参考www.vnpy.org上的教程安装
3. 将IF0000_1min.csv用ctaHistoryData.py导入MongoDB后，直接运行本文件即可回测策略

"""


from ctaBase import *
from ctaTemplate import CtaTemplate

import talib
import numpy as np


########################################################################
class AtrRsiStrategy(CtaTemplate):
	"""结合ATR和RSI指标的一个分钟线交易策略"""
	className = 'AtrRsiStrategy'
	author = u'用Python的交易员'

	# 策略参数
	atrLength = 22		  # 计算ATR指标的窗口数
	atrMaLength = 10		# 计算ATR均线的窗口数
	rsiLength = 5		   # 计算RSI的窗口数
	rsiEntry = 16		   # RSI的开仓信号
	trailingPercent = 0.8   # 百分比移动止损
	initDays = 10		   # 初始化数据所用的天数

	# 策略变量
	bar = None				  # K线对象
	barMinute = EMPTY_STRING	# K线当前的分钟

	bufferSize = 100					# 需要缓存的数据的大小
	bufferCount = 0					 # 目前已经缓存了的数据的计数
	highArray = np.zeros(bufferSize)	# K线最高价的数组
	lowArray = np.zeros(bufferSize)	 # K线最低价的数组
	closeArray = np.zeros(bufferSize)   # K线收盘价的数组
	
	atrCount = 0						# 目前已经缓存了的ATR的计数
	atrArray = np.zeros(bufferSize)	 # ATR指标的数组
	atrValue = 0						# 最新的ATR指标数值
	atrMa = 0						   # ATR移动平均的数值

	rsiValue = 0						# RSI指标的数值
	rsiBuy = 0						  # RSI买开阈值
	rsiSell = 0						 # RSI卖开阈值
	intraTradeHigh = 0				  # 移动止损用的持仓期内最高价
	intraTradeLow = 0				   # 移动止损用的持仓期内最低价

	orderList = []					  # 保存委托代码的列表

	# 参数列表，保存了参数的名称
	paramList = ['name',
				 'className',
				 'author',
				 'vtSymbol',
				 'atrLength',
				 'atrMaLength',
				 'rsiLength',
				 'rsiEntry',
				 'trailingPercent']

	# 变量列表，保存了变量的名称
	varList = ['inited',
			   'trading',
			   'pos',
			   'atrValue',
			   'atrMa',
			   'rsiValue',
			   'rsiBuy',
			   'rsiSell']

	#----------------------------------------------------------------------
	def __init__(self, ctaEngine, setting):
		"""Constructor"""
		super(AtrRsiStrategy, self).__init__(ctaEngine, setting)
		
		# 注意策略类中的可变对象属性（通常是list和dict等），在策略初始化时需要重新创建，
		# 否则会出现多个策略实例之间数据共享的情况，有可能导致潜在的策略逻辑错误风险，
		# 策略类中的这些可变对象属性可以选择不写，全都放在__init__下面，写主要是为了阅读
		# 策略时方便（更多是个编程习惯的选择）

	#----------------------------------------------------------------------
	def onInit(self):
		"""初始化策略（必须由用户继承实现）"""
		self.writeCtaLog(u'%s策略初始化' %self.name)
	
		# 初始化RSI入场阈值
		self.rsiBuy = 50 + self.rsiEntry
		self.rsiSell = 50 - self.rsiEntry

		# 载入历史数据，并采用回放计算的方式初始化策略数值
		initData = self.loadBar(self.initDays)
		for bar in initData:
			self.onBar(bar)

		self.putEvent()

	#----------------------------------------------------------------------
	def onStart(self):
		"""启动策略（必须由用户继承实现）"""
		self.writeCtaLog(u'%s策略启动' %self.name)
		self.putEvent()

	#----------------------------------------------------------------------
	def onStop(self):
		"""停止策略（必须由用户继承实现）"""
		self.writeCtaLog(u'%s策略停止' %self.name)
		self.putEvent()

	#----------------------------------------------------------------------
	def onTick(self, tick):
		"""收到行情TICK推送（必须由用户继承实现）"""
		# 计算K线
		tickMinute = tick.datetime.minute

		if tickMinute != self.barMinute:
			if self.bar:
				self.onBar(self.bar)

			bar = CtaBarData()
			bar.vtSymbol = tick.vtSymbol
			bar.symbol = tick.symbol
			bar.exchange = tick.exchange

			bar.open = tick.lastPrice
			bar.high = tick.lastPrice
			bar.low = tick.lastPrice
			bar.close = tick.lastPrice

			bar.date = tick.date
			bar.time = tick.time
			bar.datetime = tick.datetime	# K线的时间设为第一个Tick的时间

			self.bar = bar				  # 这种写法为了减少一层访问，加快速度
			self.barMinute = tickMinute	 # 更新当前的分钟
		else:							   # 否则继续累加新的K线
			bar = self.bar				  # 写法同样为了加快速度

			bar.high = max(bar.high, tick.lastPrice)
			bar.low = min(bar.low, tick.lastPrice)
			bar.close = tick.lastPrice

	#----------------------------------------------------------------------
	def onBar(self, bar):
		"""收到Bar推送（必须由用户继承实现）"""
		# 撤销之前发出的尚未成交的委托（包括限价单和停止单）
		for orderID in self.orderList:
			self.cancelOrder(orderID)
		self.orderList = []

		# 保存K线数据
		self.closeArray[0:self.bufferSize-1] = self.closeArray[1:self.bufferSize]
		self.highArray[0:self.bufferSize-1] = self.highArray[1:self.bufferSize]
		self.lowArray[0:self.bufferSize-1] = self.lowArray[1:self.bufferSize]
		
		self.closeArray[-1] = bar.close
		self.highArray[-1] = bar.high
		self.lowArray[-1] = bar.low
		
		self.bufferCount += 1
		if self.bufferCount < self.bufferSize:
			return

		# 计算指标数值
		self.atrValue = talib.ATR(self.highArray,
								  self.lowArray,
								  self.closeArray,
								  self.atrLength)[-1]
		self.atrArray[0:self.bufferSize-1] = self.atrArray[1:self.bufferSize]
		self.atrArray[-1] = self.atrValue

		self.atrCount += 1
		if self.atrCount < self.bufferSize:
			return

		self.atrMa = talib.MA(self.atrArray,
							  self.atrMaLength)[-1]
		self.rsiValue = talib.RSI(self.closeArray,
								  self.rsiLength)[-1]

		# 判断是否要进行交易
		
		# 当前无仓位
		if self.pos == 0:
			self.intraTradeHigh = bar.high
			self.intraTradeLow = bar.low

			# ATR数值上穿其移动平均线，说明行情短期内波动加大
			# 即处于趋势的概率较大，适合CTA开仓
			if self.atrValue > self.atrMa:
				# 使用RSI指标的趋势行情时，会在超买超卖区钝化特征，作为开仓信号
				if self.rsiValue > self.rsiBuy:
					# 这里为了保证成交，选择超价5个整指数点下单
					self.buy(bar.close+5, 1)

				elif self.rsiValue < self.rsiSell:
					self.short(bar.close-5, 1)

		# 持有多头仓位
		elif self.pos > 0:
			# 计算多头持有期内的最高价，以及重置最低价
			self.intraTradeHigh = max(self.intraTradeHigh, bar.high)
			self.intraTradeLow = bar.low
			# 计算多头移动止损
			longStop = self.intraTradeHigh * (1-self.trailingPercent/100)
			# 发出本地止损委托，并且把委托号记录下来，用于后续撤单
			orderID = self.sell(longStop, 1, stop=True)
			self.orderList.append(orderID)

		# 持有空头仓位
		elif self.pos < 0:
			self.intraTradeLow = min(self.intraTradeLow, bar.low)
			self.intraTradeHigh = bar.high

			shortStop = self.intraTradeLow * (1+self.trailingPercent/100)
			orderID = self.cover(shortStop, 1, stop=True)
			self.orderList.append(orderID)

		# 发出状态更新事件
		self.putEvent()

	#----------------------------------------------------------------------
	def onOrder(self, order):
		"""收到委托变化推送（必须由用户继承实现）"""
		pass

	#----------------------------------------------------------------------
	def onTrade(self, trade):
		pass


if __name__ == '__main__':
	# 提供直接双击回测的功能
	# 导入PyQt4的包是为了保证matplotlib使用PyQt4而不是PySide，防止初始化出错
	from ctaBacktesting import *
	from PyQt4 import QtCore, QtGui
	
	# 创建回测引擎
	engine = BacktestingEngine()
	
	# 设置引擎的回测模式为K线
	engine.setBacktestingMode(engine.BAR_MODE)

	# 设置回测用的数据起始日期
	engine.setStartDate('20120101')
	
	# 设置产品相关参数
	engine.setSlippage(0.2)	 # 股指1跳
	engine.setRate(0.3/10000)   # 万0.3
	engine.setSize(300)		 # 股指合约大小
	
	# 设置使用的历史数据库
	engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
	
	# 在引擎中创建策略对象
	d = {'atrLength': 11}
	engine.initStrategy(AtrRsiStrategy, d)
	
	# 开始跑回测
	engine.runBacktesting()
	
	# 显示回测结果
	engine.showBacktestingResult()
	
	## 跑优化
	#setting = OptimizationSetting()				 # 新建一个优化任务设置对象
	#setting.setOptimizeTarget('capital')			# 设置优化排序的目标是策略净盈利
	#setting.addParameter('atrLength', 11, 12, 1)	# 增加第一个优化参数atrLength，起始11，结束12，步进1
	#setting.addParameter('atrMa', 20, 30, 5)		# 增加第二个优化参数atrMa，起始20，结束30，步进1
	#engine.runOptimization(AtrRsiStrategy, setting) # 运行优化函数，自动输出结果
	
	